# import packages
import pandas as pd 
import numpy as np

# Read the three inputs: the 2014 op population (BISG file), the NC analysis
# dataset, and the NC reweights.
individual2014 = pd.read_csv(
    '/REDACTED/data/final/individualBISG2014_full_final.csv'
)
ncdata = pd.read_csv("/REDACTED/NC_Analysis_Dataset_2014_tplevel.csv")
ncweights = pd.read_csv("/REDACTED/HertzGraff/nc_reweights_v4_April.csv")

# Keep only NC records that carry a ground-truth race label (non-null black_ind).
ncdata = ncdata.dropna(subset=['black_ind'])

# Restrict the 2014 population to NC rows, then attach them to the labelled
# NC records by taxpayer_id (default inner join: unmatched rows are dropped).
nc_BISG = individual2014.loc[individual2014['state'] == "NC"]
nc_merged = ncdata.merge(nc_BISG, on='taxpayer_id')

# Only the id and the weight column are needed from the reweights file.
ncweights = ncweights.loc[:, ['taxpayer_id', 'uswgt']]
nc_final = nc_merged.merge(ncweights, on='taxpayer_id')

#######################################
# write out disparity results
#
# The report is written with print(..., file=...) inside a `with` block rather
# than by rebinding sys.stdout: the output file is always closed, the process's
# stdout is never left redirected if a computation fails, and no `sys` import
# is required.


def _audit_disparity(df):
    """Return the black-minus-non-black audit-rate gap for the rows of *df*.

    *df* must contain the columns ``black_ind``, ``aud_no_research_audits``,
    ``uswgt`` and ``aud_wgt`` (``aud_no_research_audits * uswgt``).

    Returns a tuple ``(unweighted, weighted, n)``:

    * ``unweighted`` -- difference between the mean of
      ``aud_no_research_audits`` over rows with ``black_ind == 1`` and over
      rows with ``black_ind == 0``.
    * ``weighted`` -- the same difference, with each group's rate computed as
      ``sum(aud_wgt) / sum(uswgt)``.
    * ``n`` -- ``len(df)``; this counts every row of *df*, including any row
      whose ``black_ind`` is neither 0 nor 1.
    """
    # separate into black and non-black datasets
    black = df[df.black_ind == 1]
    nonblack = df[df.black_ind == 0]

    unweighted = (
        black.aud_no_research_audits.mean()
        - nonblack.aud_no_research_audits.mean()
    )

    black_rate_wgt = black.aud_wgt.sum() / black.uswgt.sum()
    nonblack_rate_wgt = nonblack.aud_wgt.sum() / nonblack.uswgt.sum()
    weighted = black_rate_wgt - nonblack_rate_wgt

    return unweighted, weighted, len(df)


def _write_section(out_file, title, unweighted, weighted, n):
    """Write one four-line report section (title, two disparities, N) to *out_file*."""
    print(title, file=out_file)
    print("Unweighted: " + str(unweighted), file=out_file)
    print("Weighted: " + str(weighted), file=out_file)
    print("N: " + str(n), file=out_file)


# construct weighted audit measure using NC weights
# (added before subsetting so the EITC / non-EITC frames inherit the column)
nc_final['aud_wgt'] = nc_final['aud_no_research_audits'] * nc_final['uswgt']

# Full population
disparity_unwgt, disparity_wgt, full_pop_n = _audit_disparity(nc_final)

##############################################
# Within EITC
nc_final_eitc = nc_final[nc_final.isEIC == 1]
eitc_disparity_unwgt, eitc_disparity_wgt, eitc_n = _audit_disparity(nc_final_eitc)

##################################################
# Outside EITC
nc_final_non_eitc = nc_final[nc_final.isEIC == 0]
non_eitc_disparity_unwgt, non_eitc_disparity_wgt, non_eitc_n = _audit_disparity(
    nc_final_non_eitc
)

# All figures are computed before the file is opened, so a failure above does
# not leave behind an empty or half-written report.
with open("/REDACTED/nc_ground_truth_disparity.txt", "w") as _report:
    _write_section(
        _report, "Full Population", disparity_unwgt, disparity_wgt, full_pop_n
    )
    # The leading "\n" in the next two titles puts a blank line between sections.
    _write_section(
        _report, "\nEITC", eitc_disparity_unwgt, eitc_disparity_wgt, eitc_n
    )
    _write_section(
        _report,
        "\nNon-EITC",
        non_eitc_disparity_unwgt,
        non_eitc_disparity_wgt,
        non_eitc_n,
    )